{#- User-editable settings. Every tag trims surrounding whitespace, so this
    section renders nothing into the output. -#}

{#- Identity: `name` prefixes every generated resource name and host name;
    `image` is the container image for all jobs except tensorboard. -#}
{%- set name = 'dist-strat-example' -%}
{%- set image = 'gcr.io/<your_project>/keras_model_to_estimator:v1' -%}

{#- Cluster shape: replica counts per job, the nvidia.com/gpu limit placed on
    every non-tensorboard container, and on/off switches for the optional
    evaluator and tensorboard jobs. -#}
{%- set worker_replicas = 2 -%}
{%- set ps_replicas = 0 -%}
{%- set num_gpus_per_worker = 2 -%}
{%- set has_eval = false -%}
{%- set has_tensorboard = false -%}

{#- Program: `script` is run with /usr/bin/python; `cmdline_args` is a
    space-separated argument string (by default just `train_dir`), and
    `train_dir` is also passed to tensorboard as --logdir. -#}
{%- set train_dir = 'gs://<your_gcs_bucket>' -%}
{%- set script = '/keras_model_to_estimator.py' -%}
{%- set cmdline_args = train_dir -%}

{#- Credentials: `credential_secret_key` is used as the Kubernetes secretName
    mounted at /var/secrets/google; `credential_secret_json` is the file name
    under that mount that GOOGLE_APPLICATION_CREDENTIALS points to. -#}
{%- set credential_secret_json = 'key.json' -%}
{%- set credential_secret_key = 'credential' -%}

{#- Single port shared by the Services, the container ports, the TF_CONFIG
    host addresses and tensorboard's --port flag. -#}
{%- set port = 5000 -%}


{#- Number of (Service, ReplicationController) pairs to render for each job.
    The boolean switches are converted with |int, so they contribute 0 or 1. -#}
{%- set replicas = {"worker": worker_replicas,
                    "ps": ps_replicas,
                    "evaluator": has_eval|int,
                    "tensorboard": has_tensorboard|int} -%}
{#- Split on runs of whitespace rather than on a single " ": an empty
    cmdline_args, or one with repeated/leading/trailing spaces, then yields no
    empty-string entries, which would otherwise be rendered as "" arguments. -#}
{% set cmdline_arg_list = cmdline_args.split() %}

{#- Renders the worker addresses as a comma-separated run of backslash-quoted
    "<name>-worker-<index>:<port>" entries, with no surrounding brackets and no
    whitespace between entries. -#}
{%- macro worker_hosts() -%}
  {%- for idx in range(worker_replicas) -%}
    \"{{ name }}-worker-{{ idx }}:{{ port }}\"{{ "" if loop.last else "," }}
  {%- endfor -%}
{%- endmacro -%}

{#- Renders the parameter-server addresses as a comma-separated run of
    backslash-quoted "<name>-ps-<index>:<port>" entries, with no surrounding
    brackets and no whitespace between entries. -#}
{%- macro ps_hosts() -%}
  {%- for idx in range(ps_replicas) -%}
    \"{{ name }}-ps-{{ idx }}:{{ port }}\"{{ "" if loop.last else "," }}
  {%- endfor -%}
{%- endmacro -%}

{#- Renders the JSON text used as the TF_CONFIG environment value for one task.
    Arguments: task_type is the job name and task_id the replica index; both
    are written into the "task" object, and the index is emitted as a quoted
    string. The "cluster" object always lists the workers, adds "ps" only when
    ps_replicas > 0, and adds a single "evaluator" entry only when has_eval is
    set. Every double quote is backslash-escaped because the caller places the
    result inside a double-quoted YAML scalar. -#}
{%- macro tf_config(task_type, task_id) -%}
{
  \"cluster\": {
    \"worker\": [{{ worker_hosts() }}]
    {%- if ps_replicas > 0 -%}, \"ps\": [{{ ps_hosts() }}]{%- endif -%}
    {%- if has_eval -%},
    \"evaluator\": [\"{{ name }}-evaluator-0:{{ port }}\"]{%- endif -%}
  },
  \"task\": {
    \"type\": \"{{ task_type }}\",
    \"index\": \"{{ task_id }}\"
  }
}
{%- endmacro -%}

{#- Renders one Service plus one single-replica ReplicationController for every
    (job, index) pair; replicas[job] decides how many pairs a job gets, so a
    count of 0 skips the job entirely. The Service selects its pod through the
    name/job/task labels set on the controller's pod template.
    String-valued substitutions are double-quoted so YAML always reads them as
    strings, matching the already-quoted task/script/argument values; the port
    and GPU count stay bare because they must remain integers.
    NOTE(review): every Service, including the worker/ps/evaluator ones, is of
    type LoadBalancer - confirm that external exposure is intended.
    NOTE(review): the GPU limit applies to every non-tensorboard job, which
    includes ps and evaluator - confirm that is intended. -#}
{% for job in ["worker", "ps", "evaluator", "tensorboard"] -%}
{%- for i in range(replicas[job]) -%}
kind: Service
apiVersion: v1
metadata:
  name: "{{ name }}-{{ job }}-{{ i }}"
spec:
  type: LoadBalancer
  selector:
    name: "{{ name }}"
    job: "{{ job }}"
    task: "{{ i }}"
  ports:
  - port: {{ port }}
---
kind: ReplicationController
apiVersion: v1
metadata:
  name: "{{ name }}-{{ job }}-{{ i }}"
spec:
  replicas: 1
  template:
    metadata:
      labels:
        name: "{{ name }}"
        job: "{{ job }}"
        task: "{{ i }}"
    spec:
      containers:
{# tensorboard uses the stock image and requests no GPUs #}{% if job == "tensorboard" %}
      - name: tensorflow
        image: tensorflow/tensorflow
{% else %}
      - name: tensorflow
        image: "{{ image }}"
        resources:
          limits:
            nvidia.com/gpu: {{ num_gpus_per_worker }}
{% endif %}
        env:
{# tensorboard is not a cluster task, so it gets no TF_CONFIG #}{% if job != "tensorboard" %}
        - name: TF_CONFIG
          value: "{{ tf_config(job, i) }}"
{% endif %}
        - name: GOOGLE_APPLICATION_CREDENTIALS
          value: "/var/secrets/google/{{ credential_secret_json }}"
        ports:
        - containerPort: {{ port }}
{% if job == "tensorboard" %}
        command:
        - "tensorboard"
        args:
        - "--logdir={{ train_dir }}"
        - "--port={{ port }}"
{% else %}
        command:
        - "/usr/bin/python"
        - "{{ script }}"
        {%- for cmdline_arg in cmdline_arg_list %}
        - "{{ cmdline_arg }}"
        {%- endfor -%}
{% endif %}
        volumeMounts:
        - name: credential
          mountPath: /var/secrets/google
      volumes:
      - name: credential
        secret:
          secretName: "{{ credential_secret_key }}"
---
{% endfor %}
{%- endfor -%}
